# -*- coding: utf-8 -*-

"""
@Datetime: 2019/3/27
@Author: Zhang Yafei
"""
import warnings

warnings.filterwarnings(action='ignore', category=UserWarning, module='gensim')

from gensim.models.doc2vec import Doc2Vec, TaggedDocument


def read_data():
    """
    Preprocess the corpus into a tagged document-word list for Doc2Vec.

    Reads one whitespace-tokenised title per line from
    'origin_data/titles.txt' and tags each document with its 0-based
    line index.

    :return: e.g.
        [TaggedDocument(words=['contribut', 'antarctica', 'past', 'futur', 'sea-level', 'rise'], tags=[0]),
         TaggedDocument(words=['evid', 'limit', 'human', 'lifespan'], tags=[1]),
         ...]
    """
    with open('origin_data/titles.txt', 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of materialising readlines()
        # and mapping a lambda over it; each stripped line is split on
        # whitespace into its token list.
        return [TaggedDocument(line.strip().split(), [i])
                for i, line in enumerate(f)]


class Doc2VecModel(object):
    """
    Thin wrapper around gensim's Doc2Vec (PV-DM) model: builds the
    vocabulary, trains, and saves the trained model to disk.
    """

    def __init__(self, vec_size=10, alpha=0.025):
        """
        :param vec_size: dimensionality of the document vectors
        :param alpha: initial learning rate (decays to min_alpha during training)
        """
        self.model = Doc2Vec(vector_size=vec_size,
                             alpha=alpha,
                             min_alpha=0.00025,
                             min_count=1,  # keep every token, even singletons
                             dm=1)         # PV-DM (distributed memory) mode

    def run(self, documents, max_epochs=100):
        """
        Train the model on the tagged documents and save the result.

        :param documents: list of TaggedDocument (see read_data)
        :param max_epochs: number of training passes over the corpus
        :return: None; the trained model is written to
                 'results_data/all_model_titles'
        """
        # Build the vocabulary from the tagged corpus.
        self.model.build_vocab(documents)
        # Train with a single call and let gensim decay the learning
        # rate linearly from `alpha` to `min_alpha` internally.
        #
        # NOTE(review): the previous code looped `range(max_epochs + 1)`
        # (an off-by-one: 101 iterations), called train() each time with
        # epochs=self.model.iter (an attribute removed in gensim 4.x,
        # raising AttributeError), and manually decremented alpha by
        # 0.0002 per iteration — a pattern gensim's docs warn against,
        # since it multiplies the actual passes and can push alpha
        # negative for large max_epochs.
        self.model.train(documents,
                         total_examples=self.model.corpus_count,
                         epochs=max_epochs)
        # Persist the trained model.
        self.model.save('results_data/all_model_titles')


if __name__ == '__main__':
    # Build the tagged corpus, then train and persist the Doc2Vec model.
    tagged_docs = read_data()
    doc2vec = Doc2VecModel()
    doc2vec.run(documents=tagged_docs, max_epochs=100)
